
% Teacher reallocation simulation.
%
% For each subject (math, reading) this script loads a fitted value-added
% model, then, within each group of observations (district-year for policy 1,
% school-year for policy 2), reassigns teachers across the sets of students
% they currently teach so that total predicted value-added is maximized.
% Per-student gains are computed twice:
%   VA_gain_b - using the point estimates stored in ValueAddedEstimates.vahat
%   VA_gain   - averaged over numsim draws from each teacher's mixture
%               distribution stored in ValueAddedEstimates.vadist
% Results are saved to output/reallocation/<subject>_<policy>_matching.mat.
%
% NOTE(review): the random draws below are not seeded in this script, so
% VA_gain differs from run to run unless the caller sets rng beforehand.

cd('.')  % no-op as written; presumably a placeholder for the project root
addpath('matlab functions')

% Make sure the output folder exists before any long-running work starts;
% otherwise the save call at the end of the first subject/policy pass fails
% after the whole simulation has already been computed.
outdir = 'output/reallocation';
if ~exist(outdir,'dir')
    mkdir(outdir);
end

D0 = readtable('NC Data/SCC_4to5.csv');

numsim = 1000;  % number of Monte Carlo draws per teacher

for sbj = {'math','reading'}
    
    vamdl = load(sprintf('output/main/vam_%s_matching.mat',sbj{1}));
    tchcoef = vamdl.modspecs.tchcoef;
    % Coefficients whose name starts with 'yr' are excluded from the
    % reallocation objective below (presumably year-level effects - confirm).
    lvl = startsWith(tchcoef,'yr');
    ValueAddedEstimates = [vamdl.ValueAddedEstimates];

    % Keep only rows with an observed outcome, and map every row to the
    % position of its teacher in ValueAddedEstimates.
    D = D0(~isnan(D0.(vamdl.modspecs.outcome)),:);
    [~,D.tch_idx] = ismember(D.teacherid,[ValueAddedEstimates.teacherid]);        
    assert(all(D.tch_idx~=0))  % every teacher in the data must have an estimate

    for policy_sim = 1:2

        % Policy 1 reallocates within district-year, policy 2 within school-year.
        if policy_sim==1
            grpvars = {'year','districtid'};
        elseif policy_sim==2
            grpvars = {'year','districtid','schoolid'};
        end    

        [~,~,D.grp] = unique(D(:,grpvars));        
        numGRP = max(D.grp);
        TBL = cell([numGRP 1]);

        for i = 1:numGRP

            fprintf('Computing %d of %d \n',i,numGRP)
            X = D(D.grp==i,:);
            % tch_idx: distinct teachers in this group; j: for each row of X,
            % the position (1..nT) of its teacher within tch_idx.
            [tch_idx,~,j] = unique(X.tch_idx);
            nT = numel(tch_idx);
            slots = (1:nT)';

            % Point estimates, one row per teacher (assumes each vahat is a
            % column vector aligned with tchcoef - TODO confirm).
            va_hat = cat(2,ValueAddedEstimates(tch_idx).vahat)';
            va_hat = va_hat(:,~lvl);

            % Draw numsim vectors from each teacher's Gaussian mixture.
            % random(gm,numsim) is numsim-by-d, so after cat/permute va_sim is
            % teachers-by-coefficients-by-numsim.
            va_dist = cellfun(@(x) gmdistribution(x.mean',x.cov,x.pr) ...
                ,{ValueAddedEstimates(tch_idx).vadist},'unif',0);
            va_sim = cellfun(@(gm) random(gm,numsim),va_dist,'unif',0);
            va_sim = cat(3,va_sim{:});
            va_sim = permute(va_sim,[3 2 1]);
            va_sim = va_sim(:,~lvl,:);

            % Zs: per-row covariates matching the retained coefficients.
            % Z:  the same covariates summed over the rows of each teacher.
            Zs = X{:,tchcoef(~lvl)};
            Z = groupsummary(Zs,j,'sum');

            X = X(:,{'grp' 'teacherid' sprintf('lag_%sscore',sbj{1}) 'female' 'race_black'});

            % VA(a,b): total predicted value-added if teacher a takes over the
            % rows currently assigned to teacher b. The -realmax unmatched cost
            % is meant to force every teacher and every slot to be matched.
            VA = va_hat*Z';
            reallo = matchpairs(VA,-realmax,'max');
            % reallo(j,1) below is only valid if row k of reallo is the match
            % for column k, i.e. the matching is complete and column-ordered.
            assert(isequal(reallo(:,2),slots), ...
                'matchpairs returned an incomplete matching for group %d',i)
            X.VA_gain_b = sum(va_hat(reallo(j,1),:).*Zs,2) - sum(va_hat(j,:).*Zs,2);

            VA_gain = nan([size(Zs,1) numsim]);

            % Repeat the optimal assignment for every simulated draw.
            for s = 1:numsim
                
                VA = va_sim(:,:,s)*Z';
                reallo = matchpairs(VA,-realmax,'max');
                assert(isequal(reallo(:,2),slots), ...
                    'matchpairs returned an incomplete matching for group %d, draw %d',i,s)
                VA_gain(:,s) = sum(va_sim(reallo(j,1),:,s).*Zs,2) - sum(va_sim(j,:,s).*Zs,2);
                
            end

            % Average per-row gain across draws.
            X.VA_gain = mean(VA_gain,2);
    
            TBL{i} = X;

        end

        % Stack the per-group tables and write one file per subject/policy.
        TBL = cat(1,TBL{:});
        save(sprintf('output/reallocation/%s_%d_matching.mat',sbj{1},policy_sim),'TBL')

    end

end




